import gymnasium as gym
from gymnasium import spaces
import numpy as np
import random
import matplotlib.pyplot as plt
import os



class SnakeEnv(gym.Env):
    def __init__(self, render=False, snake=None, board_x_size=None, board_y_size=None):
        """Create a Snake environment on a board of at most MAX_BOARD x MAX_BOARD cells.

        Args:
            render: when True, open an interactive matplotlib figure and draw
                the board on every reset/step.
            snake: optional fixed starting snake, a NumPy array of [x, y] cells
                with the head first; it is copied on every reset. When None a
                random single-cell snake is used.
            board_x_size: fixed playable width; when None a random width in
                [4, MAX_BOARD] is drawn on every reset.
            board_y_size: fixed playable height; same convention as board_x_size.
        """
        super().__init__()
        # Maximum board size. The observation always covers the full
        # MAX_BOARD x MAX_BOARD grid; cells outside the playable area are marked 1.
        self.MAX_BOARD = 20
        self.should_render = render
        self.max_moves = int(self.MAX_BOARD**2 + 5)
        self.mov_closer_rew_frac = 1 / self.max_moves
        self.action_space = spaces.Discrete(4)
        # flatten_obs returns a 1-D vector: flattened grid (MAX_BOARD**2),
        # remaining moves (1), safe-move flags (4), direction to apple (2) and
        # steps taken closer (1). The space must be 1-D as well so observations
        # actually belong to it.
        self.observation_space = spaces.Box(
            low=0,
            high=1,
            shape=(self.MAX_BOARD**2 + 1 + 4 + 2 + 1,),
            dtype=np.float32,
        )
        if self.should_render:
            plt.ion()
            self.fig, self.ax = plt.subplots()
            self.full_color = np.array(255, dtype=np.uint8)

        self.board_x_size = board_x_size
        self.board_y_size = board_y_size
        # NOTE(review): not read anywhere in this class; kept for compatibility.
        self.norm_mul = np.array([[2], [3]])
        # Grid encoding (indexed img[y, x]): 0 empty, 1 outside the playable
        # area, 2 apple, >= 3 snake (3 is the head).
        self.img = np.zeros((self.MAX_BOARD, self.MAX_BOARD), dtype=np.int16)
        self.custom_snake = snake
        self.snake_body = None
        self.body_positions_set = None
        self.remaining_moves = np.array([self.max_moves], dtype=np.int16)
        # Action index -> [dx, dy] step.
        self.directions = [
            np.array([-1, 0], dtype=np.int8),  # Left
            np.array([1, 0], dtype=np.int8),   # Right
            np.array([0, 1], dtype=np.int8),   # Down
            np.array([0, -1], dtype=np.int8),  # Up
        ]
        self.max_len = self.MAX_BOARD**2
        # Values used by draw_snake; index 0 (value 3) is the head.
        self.snake_numbers = np.arange(3, self.max_len + 3, dtype=np.int16)
        
        
    def step(self, action):
        self.reward = 0
        # Change the head position based on the button direction
        self.snake_head += self.directions[action]
        col_with_apple = False
        snake_tail = np.copy(self.snake_position[-1])
        # Increase Snake length on eating apple
        if np.array_equal(self.snake_head, self.apple_position):
            self.snake_position = np.append([self.snake_head], self.snake_position, axis=0)
            self.snake_head = np.copy(self.snake_position[0])
            if len(self.snake_position) == self.board_max_len:
                self.done = True
            else:
                self.apple_position, self.score = self.collision_with_apple(
                    self.score, self.current_board_size[0] - 1, self.current_board_size[1] - 1
                )
            col_with_apple = True
        else:
            self.snake_position[1:] = self.snake_position[:-1]
            self.snake_position[0] = self.snake_head.copy()


        # On collision kill the snake
        if (
            self.collision_with_boundaries(self.snake_head) == 1
            or self.collision_with_self(self.snake_position) == 1
        ):
            self.done = True

        # Display Snake
        if not self.done or col_with_apple:
            self.update_snake_pos_in_img(self.img, snake_tail, ate_apple=col_with_apple)

            # self.draw_snake(self.img, self.snake_position)
        else:
            self.update_snake_pos_in_img(self.img, snake_tail, done=True, ate_apple=col_with_apple)
            # self.draw_snake(self.img, self.snake_position, done=True)

        # Display Apple
        if not self.done and col_with_apple:
            self.img[self.apple_position[1], self.apple_position[0]] = 2

        dists_neg = self.snake_head - self.apple_position
        self.direction_to_apple = np.sign(dists_neg) + 1
        dists = np.abs(dists_neg.copy())
        past_dists = np.abs((self.snake_head - self.directions[action]) - self.apple_position)
        reward_for_eating_apple = 2
        if self.done and not col_with_apple:
            self.reward = -100
        elif col_with_apple:
            
            self.reward = reward_for_eating_apple * (1 - self.mov_closer_rew_frac * self.steps_taken_closer)
            self.steps_taken_closer = 0
        
        elif dists[0] < past_dists[0]:
            self.reward = reward_for_eating_apple * self.mov_closer_rew_frac * 0.75
            self.steps_taken_closer += 1
        elif dists[1] < past_dists[1]:
            self.reward = reward_for_eating_apple * self.mov_closer_rew_frac * 0.75
            self.steps_taken_closer += 1
        else:
            self.reward = 0

        self.remaining_moves[0] -= 1
        if not self.done and self.remaining_moves[0] == 0:
            self.done = True
            self.reward = -100
        elif col_with_apple:
            self.remaining_moves[0] = self.max_moves
        
        
        if self.should_render:
            self.render(img=self.img)
        self.observation = self.flatten_obs(self.img, self.remaining_moves, self.safe_moves_array(self.snake_head, self.snake_position), self.direction_to_apple, self.steps_taken_closer)
        info = {}
        return (
            self.observation,
            self.reward,
            self.done,
            False,
            info,
        )

    def reset(self, seed=None):
        self.done = False
        self.snake_body = None
        self.body_positions_set = None
        if self.board_x_size is None:
            board_x_size = random.randint(4, self.MAX_BOARD)
        else:
            board_x_size = self.board_x_size
        if self.board_y_size is None:
            board_y_size = random.randint(4, self.MAX_BOARD)
        else:
            board_y_size = self.board_y_size
        self.board_max_len = board_x_size * board_y_size
        self.current_board_size = np.array([board_x_size, board_y_size], dtype=np.int16)
        self.img.fill(0)
        self.img[board_y_size:, :] = 1
        self.img[:, board_x_size:] = 1
        self.remaining_moves[0] = self.max_moves
        # Initial Snake and Apple position
        self.snake_position = self.custom_snake.copy() if self.custom_snake is not None else self.random_start_position(board_x_size-1, board_y_size-1)
        self.apple_position = self.collision_with_apple(0, board_x_size - 1, board_y_size - 1)[0]

        # Display Apples
        self.img[self.apple_position[1], self.apple_position[0]] = 2
        # Display Snake
        self.draw_snake(self.img, self.snake_position)
        self.snake_head = np.copy(self.snake_position[0])
        self.mins_neg = self.snake_head - self.apple_position
        self.direction_to_apple = np.sign(self.mins_neg) + 1

        self.score = 0
        self.reward = 0
        self.steps_taken_closer = 0
        if self.should_render:
            self.render(img=self.img)
        self.observation = self.flatten_obs(self.img, self.remaining_moves, self.safe_moves_array(self.snake_head, self.snake_position), self.direction_to_apple, self.steps_taken_closer)
        info = {}
        return (
            self.observation,
            info,
        )  # reward, done, info can't be included
    
    # flattens obs into a 1d array and normalizes it
    def flatten_obs(self, image, remaining_moves, safe_moves, direction_to_apple, steps_taken_closer):
        image = np.float32(image.flatten()) / (self.max_len + 3)
        remaining_moves = np.float32(remaining_moves) / self.max_moves
        safe_moves = np.float32(safe_moves)
        direction_to_apple = np.float32(direction_to_apple) / 2
        steps_taken_closer = np.float32(steps_taken_closer) / self.max_moves
        return np.hstack((image, remaining_moves, safe_moves, direction_to_apple, steps_taken_closer))
        
    def draw_snake(self, img, snake_position, done = False):
        # uses the snake numbers to draw the snake. with the head being 1 and the body being numberd up to the max length + 1 which is the tail
        if not done:
            img[snake_position[0][1], snake_position[0][0]] = self.snake_numbers[0]
        if (len(snake_position) > 1):
            img[snake_position[1:, 1], snake_position[1:, 0]] = self.snake_numbers[-len(snake_position) + 1:]
        #print(self.snake_numbers[-len(snake_position) + 1:])
        #print(img)





    # renders the game with matplotlib
    def render(self, img=None, mode="human"):
        """Draw a board grid with matplotlib.

        Args:
            img: grid to draw (indexed ``img[y, x]``); defaults to ``self.img``
                so a plain ``env.render()`` works.
            mode: accepted for API compatibility; ignored.

        Colour coding: 1 (outside the playable area) blue, 2 (apple) red,
        >= 3 (snake) green scaled by the cell value, with a constant blue of 127.
        """
        if img is None:
            img = self.img
        if getattr(self, "ax", None) is None:
            # The figure is only created in __init__ when render=True; create it
            # lazily so rendering also works for environments built without it.
            plt.ion()
            self.fig, self.ax = plt.subplots()
        self.ax.clear()
        rendered_img = np.zeros(img.shape + (3,), dtype=np.uint8)
        rendered_img[img == 1, 2] = 255
        rendered_img[img == 2, 0] = 255
        snake_cells = img >= 3
        # Same normaliser as flatten_obs (403 on the default 20x20 board), so the
        # green channel spans 1..255.
        rendered_img[snake_cells, 1] = img[snake_cells] / (self.max_len + 3) * 254 + 1
        rendered_img[snake_cells, 2] = 127
        self.ax.imshow(rendered_img)
        self.fig.canvas.draw()
        plt.pause(0.01)

    # updates apple position and score when snake collides with apple
    def collision_with_apple(self, score, max_x, max_y):
        apple_x_position = random.randint(0, max_x)
        apple_y_position = random.randint(0, max_y)
        body_positions_set = self.get_body_positions_set(self.snake_position)
        while (apple_x_position, apple_y_position) in body_positions_set:
            apple_x_position = random.randint(0, max_x)
            apple_y_position = random.randint(0, max_y)
        apple_position = np.array([apple_x_position, apple_y_position], dtype=np.int16)
        score += 1
        return apple_position, score

    # checks if the snake has collided with the boundaries
    def collision_with_boundaries(self, snake_head):
        snake_head_x = snake_head[0]
        snake_head_y = snake_head[1]
        if (
            snake_head_x >= self.current_board_size[0]
            or snake_head_x < 0
            or snake_head_y >= self.current_board_size[1]
            or snake_head_y < 0
        ):
            return 1
        else:
            return 0

    # checks if the snake has collided with itself
    def collision_with_self(self, snake_position):
        body_positions_set = self.get_body_positions_set(snake_position)
        # if the length of the set is not equal to the length of the snake, then it must have tried adding a duplicate, so it collided with itself
        if len(body_positions_set) != len(snake_position):
            return 1
        else:
            return 0
        
    # gets the body positions as a set (cached and patched incrementally)
    def get_body_positions_set(self, snake_position):
        """Return the set of (x, y) tuples occupied by the snake.

        The set is cached in ``self.body_positions_set`` (reset() sets it back to
        None) and patched instead of rebuilt on each call. Alongside it,
        ``self.last_tail_tuple`` remembers the tail seen on the last update.

        The patching assumes that between calls the snake either stayed the
        same, grew by one cell at the head (tail unchanged), or moved by one
        step (new head added, old tail removed).
        TODO(review): any other kind of change to ``snake_position`` leaves the
        cache stale.

        After a move onto a body cell, the set ends up one element smaller than
        the snake. collision_with_self relies on that size difference.
        """
        # if none, goes through all the positions and adds them to the set
        if self.body_positions_set is None:
            self.body_positions_set = set(tuple(pos) for pos in snake_position)
            self.last_tail_tuple = tuple(snake_position[-1])
        elif len(self.body_positions_set) != len(snake_position):
            # if the length of the set is not equal to the length of the snake, then it must have grown
            # (the tail did not move, so only the new head is added).
            # NOTE(review): this branch also runs when a previous call recorded a
            # self-collision (set smaller than the snake); step() ends the episode then.
            tuple_snake_head = tuple(snake_position[0])
            self.body_positions_set.add(tuple_snake_head)
        elif self.last_tail_tuple != tuple(snake_position[-1]):
            # if the last tail tuple is not equal to the last position of the snake, then it must have moved:
            # drop the vacated tail cell before adding the new head, so that moving
            # into the old tail cell keeps the set size equal to the snake length.
            self.body_positions_set.remove(self.last_tail_tuple)
            tuple_snake_head = tuple(snake_position[0])
            self.body_positions_set.add(tuple_snake_head)
            self.last_tail_tuple = tuple(snake_position[-1])

        return self.body_positions_set
    
    # gets the safe moves as a 1d array
    # 1 is safe, 0 is not safe
    def safe_moves_array(self, snake_head, snake_position):
        body_positions_set = self.get_body_positions_set(snake_position)
        snake_head_x = snake_head[0]
        snake_head_y = snake_head[1]
        tail_tuple = tuple(snake_position[-1])
        safe_moves = np.ones((4,), dtype=np.int8)
        move_left_tuple = (snake_head_x - 1, snake_head_y)
        move_right_tuple = (snake_head_x + 1, snake_head_y)
        move_up_tuple = (snake_head_x,  snake_head_y - 1)
        move_down_tuple = (snake_head_x, snake_head_y + 1)
        # checks tail tuple, becuase it is safe to move into the tail
        if (move_left_tuple in body_positions_set or move_left_tuple[0] < 0) and move_left_tuple != tail_tuple:
            safe_moves[0] = 0
        if (move_right_tuple in body_positions_set or move_right_tuple[0] >= self.current_board_size[0]) and move_right_tuple != tail_tuple:
            safe_moves[1] = 0
        if (move_up_tuple in body_positions_set or move_up_tuple[1] < 0) and move_up_tuple != tail_tuple:
            safe_moves[2] = 0
        if (move_down_tuple in body_positions_set or move_down_tuple[1] >= self.current_board_size[1]) and move_down_tuple != tail_tuple:
            safe_moves[3] = 0
        # left, right, up, down
        return safe_moves
    
    # gets a random starting position for the snake that is at least 3 long using the hamiltonian path to prevent overfitting
    def random_start_position(self, max_x, max_y):
        rand_x_start = random.randint(0, max_x)
        rand_y_start = random.randint(0, max_y)
        snake = np.array([[rand_x_start, rand_y_start]], dtype=np.int16)
        return snake


    def update_snake_pos_in_img(self, img, past_tail_pos, done=False, ate_apple = False):
        """Incrementally update the snake cells of ``img`` after one move.

        Expects ``self.snake_position`` to already hold the post-move snake
        (head first). Only the cells that change are written; the whole snake
        is not redrawn.

        Args:
            img: board grid indexed ``img[y, x]``; modified in place.
            past_tail_pos: [x, y] of the tail cell from before the move.
            done: when True the head is not drawn (step() passes this when the
                snake died, in which case the head may be off the board).
            ate_apple: True when the snake grew this step, so the tail stayed put.
        """


        if len(self.snake_position) > 1 and not ate_apple:
            # Regular move: give the old head cell (now segment 1) a body value,
            # then age every body cell by one (segment 1 included), then clear the
            # vacated tail cell.
            img[self.snake_position[1][1], self.snake_position[1][0]] = (self.max_len + 3) - len(self.snake_position) + 1
            img[self.snake_position[1:, 1], self.snake_position[1:, 0]] += 1
            img[past_tail_pos[1], past_tail_pos[0]] = 0
        elif len(self.snake_position) > 1 and ate_apple:
            # Growth: existing body cells and the tail keep their values; only the
            # new neck (the old head cell) needs a body value. It appears intended
            # to sit one below the next segment, assuming consecutive numbering — TODO confirm.
            img[self.snake_position[1][1], self.snake_position[1][0]] = (self.max_len + 3) - len(self.snake_position) + 2
        else:
            # Single-cell snake: just clear the cell it left.
            img[past_tail_pos[1], past_tail_pos[0]] = 0

        # The head is written last, so a head that moved into the old tail cell
        # is not erased by the tail-clearing above. The head is always encoded as 3.
        if not done:
            img[self.snake_position[0][1], self.snake_position[0][0]] = 3